import math
from functools import partial

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Linear, Module
from torch.func import functional_call, vmap, grad
import einx
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange, Reduce
from tensordict import TensorDict

from associative_scan import associative_scan, binary_operator, pad_at_dim


"""
ein notation:
b - batch （批次）
n - sequence （序列）
d - feature dimension （特征维度）
c - intra-chunk （块内维度）
"""


# 使用 partial 为 Linear 层创建一个不带偏置的版本
LinearNoBias = partial(Linear, bias = False)


def exists(v):
    """
    检查变量是否存在（不为 None）。

    参数:
        v (Any): 任意变量。

    返回:
        bool: 如果 v 不为 None，则返回 True，否则返回 False。
    """
    return v is not None


def default(v, d):
    """
    如果变量存在（不为 None），则返回变量本身；否则返回默认值。

    参数:
        v (Any): 任意变量。
        d (Any): 默认值。

    返回:
        Any: 如果 v 存在，则返回 v；否则返回 d。
    """
    return v if exists(v) else d


def identity(t):
    """
    返回输入张量本身。

    参数:
        t (Tensor): 输入张量。

    返回:
        Tensor: 输入张量。
    """
    return t


def round_down_multiple(seq, mult):
    """
    将序列长度向下取整到指定倍数的倍数。

    参数:
        seq (int): 序列长度。
        mult (int): 倍数。

    返回:
        int: 向下取整后的序列长度。
    """
    return seq // mult * mult


def round_up_multiple(seq, mult):
    """
    将序列长度向上取整到指定倍数的倍数。

    参数:
        seq (int): 序列长度。
        mult (int): 倍数。

    返回:
        int: 向上取整后的序列长度。
    """
    return math.ceil(seq / mult) * mult


def pack_one_with_inverse(t, pattern):
    """
    打包张量并返回用于解包的逆函数。

    参数:
        t (Tensor): 需要打包的张量。
        pattern (Tuple[int, ...]): 打包模式，指定每个维度如何分割。

    返回:
        Tuple[Tensor, Callable]: 打包后的张量和一个用于解包的函数。
    """
    packed, packed_shape = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        """
        解包张量。

        参数:
            out (Tensor): 需要解包的张量。
            inv_pattern (Tuple[int, ...], 可选): 解包模式，默认为 None。如果为 None，则使用默认的打包模式。

        返回:
            Tensor: 解包后的张量。
        """
        inv_pattern = default(inv_pattern, pattern)
        return unpack(out, packed_shape, inv_pattern)[0]

    return packed, inverse


def softclamp_max(t, max_value):
    """
    对张量进行软裁剪，限制其最大值。

    参数:
        t (Tensor): 输入张量。
        max_value (float): 最大值。

    返回:
        Tensor: 软裁剪后的张量。
    """
    half_max_value = max_value / 2
    return ((t / half_max_value).tanh() * half_max_value) + half_max_value


def softclamp_grad_norm(t, max_value):
    """
    对梯度进行软裁剪，限制其范数。

    参数:
        t (Tensor): 输入张量。
        max_value (float): 最大范数。

    返回:
        Tensor: 软裁剪后的梯度。
    """
    # 打包张量，以便在解包时恢复原始形状
    t, inverse = pack_one_with_inverse(t, 'bn *')
    
    # 计算梯度的范数
    norm = t.norm(dim = -1, keepdim = True)
    # 对范数进行软裁剪
    clamped_norm = softclamp_max(norm, max_value)

    # 根据范数的比例调整梯度
    t = t * (clamped_norm / norm)
    # 解包张量，恢复原始形状
    return inverse(t)


class MultiheadRMSNorm(Module):
    """
    多头RMS归一化（Multihead RMSNorm）模块。

    该模块对输入张量应用RMS归一化，并使用多头参数对每个头进行缩放。
    """
    def __init__(self, dim, heads):
        """
        初始化多头RMS归一化模块。

        参数:
            dim (int): 特征维度。
            heads (int): 头的数量。
        """
        super().__init__()
        # 初始化RMS归一化层，不使用可学习的仿射参数
        self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
        # 初始化多头缩放参数，形状为 (heads, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))

    def forward(self, x):
        """
        前向传播方法。

        参数:
            x (Tensor): 输入张量，形状为 (batch_size, ..., dim)。

        返回:
            Tensor: 归一化并缩放后的张量，形状与输入相同。
        """
        # 对输入张量应用RMS归一化
        # 将多头缩放参数与归一化后的张量相加，并进行缩放
        # gamma 的形状为 (heads, 1, dim)，通过广播机制与 normed 对齐
        return self.rmsnorm(x) * (self.gamma + 1.)


class MemoryMLP(Module):
    """
    记忆多层感知机（Memory MLP）模块。

    该模块由多个线性层组成，每个线性层后面跟随一个SiLU激活函数（除了第一个线性层）。
    """
    def __init__(
        self,
        dim,
        depth
    ):
        """
        初始化记忆MLP模块。

        参数:
            dim (int): 输入和输出的特征维度。
            depth (int): MLP的深度，即线性层的数量。
        """
        super().__init__()
        # 初始化参数列表，每个参数是一个线性层的权重矩阵，形状为 (dim, dim)
        self.weights = nn.ParameterList([nn.Parameter(torch.randn(dim, dim)) for _ in range(depth)])

    def forward(
        self,
        x
    ):
        """
        前向传播方法。

        参数:
            x (Tensor): 输入张量，形状为 (batch_size, ..., dim)。

        返回:
            Tensor: MLP的输出，形状与输入相同。
        """
        for ind, weight in enumerate(self.weights):
            # 判断是否是第一个线性层
            is_first = ind == 0

            if not is_first:
                # 如果不是第一个线性层，则应用SiLU激活函数
                x = F.silu(x)

            # 应用线性层
            x = x @ weight

        return x


def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
    """
    默认的自适应步长转换函数。

    将自适应步长转换为学习率，范围从0到max_lr。

    参数:
        adaptive_step (Tensor): 自适应步长张量。
        max_lr (float, 可选): 最大学习率，默认为1e-2。

    返回:
        Tensor: 转换后的学习率张量。
    """
    return adaptive_step.sigmoid() * max_lr


def default_loss_fn(pred, target):
    """
    默认的损失函数。

    计算预测值与目标值之间的均方误差（MSE）。

    参数:
        pred (Tensor): 预测值张量。
        target (Tensor): 目标值张量。

    返回:
        Tensor: 计算得到的损失值。
    """
    return (pred - target).pow(2).mean(dim = -1)


class NeuralMemory(Module):
    """
    神经记忆模块（Neural Memory Module）。

    该模块实现了神经记忆机制，通过记忆模型存储和检索信息，并在训练过程中动态调整学习率和动量。
    """
    def __init__(
        self,
        dim,
        chunk_size = 1,
        dim_head = None,
        heads = 1,
        model: Module | None = None,
        store_memory_loss_fn = default_loss_fn,
        adaptive_step_transform = default_adaptive_step_transform,
        pre_rmsnorm = True,
        post_rmsnorm = True,
        max_grad_norm: float | None = None,
        use_accelerated_scan = False,
        default_mlp_kwargs: dict = dict(
            depth = 2
        )
    ):
        """
        初始化神经记忆模块。

        参数:
            dim (int): 特征维度。
            chunk_size (int, 可选): 块大小，默认为1。
            dim_head (int, 可选): 每个注意力头的维度，默认为 None。如果为 None，则使用 `dim`。
            heads (int, 可选): 注意力头的数量，默认为1。
            model (Module, 可选): 记忆模型，默认为 None。如果为 None，则使用默认的 `MemoryMLP` 模型。
            store_memory_loss_fn (Callable[[Tensor, Tensor], Tensor], 可选): 存储记忆时的损失函数，默认为默认的损失函数。
            adaptive_step_transform (Callable[[Tensor], Tensor], 可选): 自适应步长转换函数，默认为默认的转换函数。
            pre_rmsnorm (bool, 可选): 是否在存储前应用RMS归一化，默认为 True。
            post_rmsnorm (bool, 可选): 是否在存储后应用RMS归一化，默认为 True。
            max_grad_norm (float, 可选): 存储记忆时的最大梯度范数，默认为 None。
            use_accelerated_scan (bool, 可选): 是否使用加速扫描，默认为 False。
            default_mlp_kwargs (Dict[str, Any], 可选): 默认的MLP参数，默认为深度为2。
        """
        super().__init__()
        # 如果未指定每个头的维度，则使用特征维度
        dim_head = default(dim_head, dim)

        # norms
        # 归一化层
        # 检索前的RMS归一化
        self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
        # 存储前的RMS归一化
        self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
        # 存储后的多头RMS归一化
        self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()

        # maybe multi-headed
        # 多头处理
        # 计算内部特征维度
        dim_inner = dim_head * heads

        # 保存注意力头的数量
        self.heads = heads

        # 将批次和头维度合并
        self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
        # 将头和批次维度分开
        self.merge_heads = Rearrange('b h n d -> b n (h d)')
        # 如果有多个头，则使用线性层合并头；否则，使用恒等函数
        self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()

        self.retrieve_gate = nn.Sequential(
            LinearNoBias(dim, heads),  # 线性层，将特征维度映射到头的数量
            Rearrange('b n h -> b h n 1'),  # 重塑张量形状
            nn.Sigmoid()  # 应用Sigmoid激活函数
        ) if heads > 1 else None  # 如果只有一个头，则不需要门控机制

        # memory mlp
        # 记忆模型
        if not exists(model):
            # 如果未提供记忆模型，则使用默认的 `MemoryMLP` 模型
            model = MemoryMLP(dim_head, **default_mlp_kwargs)

        assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'

        # the memory is the weights of the model
        # 保存记忆模型
        self.memory_model = model

        # the chunk size within the paper where adaptive step, momentum, weight decay are shared
        # 保存块大小
        self.chunk_size = chunk_size

        # prepare function for per sample gradients from model above, using torch.func
        # 准备用于计算每个样本梯度的函数，使用 torch.func
        def forward_and_loss(params, inputs, loss_weights, target):
            # 使用记忆模型进行前向传播
            pred = functional_call(self.memory_model, params, inputs)
            # 计算损失，默认为均方误差
            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
            # 乘以损失权重
            loss = loss * loss_weights
            return loss.sum()

        # 对每个样本计算梯度
        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))

        # queries for retrieving from the model
        # 查询函数，用于从模型中检索信息
        self.to_queries = LinearNoBias(dim, dim_inner) # 线性层，将特征维度映射到内部特征维度

        # keys and values for storing to the model
        # 键和值函数，用于向模型中存储信息
        self.to_keys_values = LinearNoBias(dim, dim_inner * 2)  # 线性层，将特征维度映射到键和值维度
        self.store_memory_loss_fn = store_memory_loss_fn  # 保存存储记忆时的损失函数

        # empty memory embed
        # 空记忆嵌入
        # 初始化空记忆嵌入为全零张量
        self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
        # 使用正态分布初始化空记忆嵌入
        nn.init.normal_(self.empty_memory_embed, std = 0.02)

        # learned adaptive learning rate and momentum
        # todo - explore mlp layerwise learned lr / momentum
        # 学习到的自适应学习率和动量
        self.to_momentum = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),  # 对块内的特征进行平均
            LinearNoBias(dim, heads),  # 线性层，将特征维度映射到头的数量
            Rearrange('b n h -> (b h) n 1')  # 重塑张量形状
        )

        self.to_adaptive_step = nn.Sequential(
            LinearNoBias(dim, heads),  # 线性层，将特征维度映射到头的数量
            Rearrange('b n h -> (b h) n')  # 重塑张量形状
        )

        # 保存自适应步长转换函数
        self.adaptive_step_transform = adaptive_step_transform

        # allow for softclamp the gradient norms for storing memories
        # 允许对存储记忆时的梯度范数进行软裁剪
        self.max_grad_norm = max_grad_norm

        # weight decay factor
        # 权重衰减因子
        self.to_decay_factor = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),  # 对块内的特征进行平均
            LinearNoBias(dim, heads),  # 线性层，将特征维度映射到头的数量
            Rearrange('b n h -> (b h) n 1')  # 重塑张量形状
        )

        # maybe use accelerated scan
        # 是否使用加速扫描
        self.use_accelerated_scan = use_accelerated_scan

    def init_weights_and_momentum(self):
        """
        初始化记忆模型的权重和动量。

        返回:
            Tuple[TensorDict, TensorDict]: 初始化的权重和动量，分别为 TensorDict 对象。
        """
        # 获取记忆模型的所有参数，并将其转换为 TensorDict 对象
        params = TensorDict(dict(self.memory_model.named_parameters()))

        # 初始化权重为零张量
        init_weights = params.clone().zero_()
        # 初始化动量为零张量
        init_momentum = params.clone().zero_()

        # 返回初始化的权重和动量
        return init_weights, init_momentum

    def init_empty_memory_embed(self, batch, seq_len):
        """
        初始化空记忆嵌入。

        参数:
            batch (int): 批次大小。
            seq_len (int): 序列长度。

        返回:
            Tensor: 初始化后的空记忆嵌入，形状为 (batch, seq_len, dim)。
        """
        # 重复空记忆嵌入，生成形状为 (batch, seq_len, dim) 的张量
        return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)

    def store_memories(
        self,
        seq,
        past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
    ):
        """
        存储记忆并更新记忆模型的权重和动量。

        参数:
            seq (Tensor): 输入序列，形状为 (batch, seq_len, dim)。
            past_state (Tuple[Dict[str, Tensor], Dict[str, Tensor]]): 过去的状态，包含权重和动量。

        返回:
            Tuple[Dict[str, Tensor], Tuple[Dict[str, Tensor], Dict[str, Tensor]]]: 更新后的权重和动量，以及新的状态。
        """
        # 对输入序列应用存储前的归一化
        seq = self.store_norm(seq)

        # curtail sequence by multiple of the chunk size
        # only a complete chunk of the sequence provides the memory for the next chunk
        # 计算序列长度和块大小
        seq_len, chunk_size = seq.shape[-2], self.chunk_size
        # 将序列长度向下取整到块大小的倍数，确保每个块完整
        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)

        # 截断序列，使其长度为块大小的倍数
        seq = seq[:, :round_down_seq_len]

        # curr weights + past weights, in the case that the initial weights are learned
        # 获取当前记忆模型的权重
        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

        # 将过去的状态转换为 TensorDict 对象
        past_state = tuple(TensorDict(d) for d in past_state)
        past_weights, past_momentum = past_state

        # 将当前权重与过去权重相加
        curr_weights = curr_weights + past_weights

        # pack batch and sequence dimension
        # 计算自适应学习率：
        # 对输入序列应用自适应步长模块（to_adaptive_step），然后应用自适应步长转换函数（adaptive_step_transform）
        adaptive_lr = self.to_adaptive_step(seq)
        adaptive_lr = self.adaptive_step_transform(adaptive_lr)

        # 计算自适应动量：
        # 对输入序列应用动量模块（to_momentum），然后使用 sigmoid 函数将其值压缩到 (0, 1) 之间。
        adaptive_momentum = self.to_momentum(seq).sigmoid()

        # 计算权重衰减因子：
        # 对输入序列应用衰减因子模块（to_decay_factor），然后使用 sigmoid 函数将其值压缩到 (0, 1) 之间。
        decay_factor = self.to_decay_factor(seq).sigmoid()

        # keys and values
        # 分离键和值：
        # 对输入序列应用键值模块（to_keys_values），然后将其在最后一个维度上分割成两部分，分别作为键和值。
        keys, values = self.to_keys_values(seq).chunk(2, dim = -1)

        # maybe multi head
        # 处理多头：
        # 对键和值应用多头重塑（split_heads），将批次和头数维度合并。
        keys, values = map(self.split_heads, (keys, values))

        # 获取批次大小
        batch = keys.shape[0]

        # take care of chunking
        # 处理块：
        # 将键和值在序列维度上重塑为 (batch * n, c, d)，其中 c 是块内维度，d 是特征维度。
        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))

        # 重塑自适应学习率：
        # 将自适应学习率重塑为 (batch * n, c)，以便与键和值对齐。
        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = self.chunk_size)

        # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
        # 计算梯度并计算辅助损失：
        # 使用 per_sample_grad_fn 计算每个样本的梯度，传入当前权重、键、自适应学习率和值。
        grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)

        grads = TensorDict(grads)

        # maybe softclamp grad norm
        # 如果存在最大梯度范数，则对梯度进行软裁剪
        if exists(self.max_grad_norm):
            grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))

        # restore batch and sequence dimension
        # 恢复批次和序列维度：
        # 将梯度张量从 (batch * n, ...) 重塑为 (batch, n, ...)。
        grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))

        # negative gradients, adaptive lr already applied as loss weight
        # 计算惊喜（surprises）：
        # 将梯度取负数，因为梯度下降需要负梯度。
        surprises = grads.apply(lambda t: -t)

        # determine scan function
        # 定义默认的关联扫描函数：
        # 使用 associative_scan 和 binary_operator 对输入的 gates 和 inputs 进行扫描。
        def default_associative_scan(gates, inputs):
            _, outputs = associative_scan(binary_operator, (gates, inputs))
            return outputs

         # 如果使用加速扫描：
        if self.use_accelerated_scan:
            from accelerated_scan.triton import scan as triton_scan
            from accelerated_scan.warp import scan as warp_scan

            scan = triton_scan if seq.is_cuda else warp_scan

            # 定义加速扫描函数：
            # 1. 对 gates 和 inputs 进行扩展和重塑。
            # 2. 对序列长度进行填充，使其为2的幂。
            # 3. 调用扫描函数。
            # 4. 截取填充后的结果，并恢复原始形状。
            def accelerate_scan_fn(gates, inputs):
                gates = gates.expand_as(inputs)
                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))

                seq_len = gates.shape[-1]
                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))

                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))

                outputs = scan(gates.contiguous(), inputs.contiguous())

                outputs = outputs[..., :seq_len]
                outputs = rearrange(outputs, 'b d n -> b n d')
                return outputs

            scan_fn = accelerate_scan_fn
        else:
            scan_fn = default_associative_scan

        # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
        # 计算动量和更新：
        # 1. 对每个参数名和对应的惊喜（surprise）进行迭代。
        # 2. 使用 pack_one_with_inverse 对惊喜进行打包，并获取逆函数。
        # 3. 使用 scan_fn 计算动量。
        # 4. 再次使用 scan_fn 计算更新（考虑权重衰减）。
        # 5. 将更新和动量逆打包，并存储到 updates 和 next_momentum 中。
        next_momentum = TensorDict()
        updates = TensorDict()

        for param_name, surprise in surprises.items():

            surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')

            # derive momentum with associative scan - eq (10)
            # 计算动量：
            # 使用关联扫描函数，根据自适应动量和惊喜计算动量。
            momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper

            # use associative scan again for learned forgetting (weight decay) - eq (13)
            # 计算更新：
            # 使用关联扫描函数，根据权重衰减因子和动量计算更新。
            update = scan_fn(1. - decay_factor, momentum) # momentum is S / surprise in the paper

            updates[param_name] = inverse_pack(update)
            next_momentum[param_name] = inverse_pack(momentum)

        # compute the next weight per batch
        # 计算每个批次的下一个权重：
        # 对每个参数，获取最后一个更新，并将其添加到当前权重中。
        last_update = updates.apply(lambda t: t[:, -1])

        next_state = (curr_weights + last_update, next_momentum)

        return updates, next_state

    def retrieve_memories(
        self,
        seq,
        past_weights: dict[str, Tensor] | None = None,
    ):
        """
        从记忆中检索信息。

        参数:
            seq (Tensor): 输入序列，形状为 (batch, seq_len, dim)。
            past_weights (Dict[str, Tensor], 可选): 过去的权重，默认为 None。

        返回:
            Tensor: 检索到的记忆，形状为 (batch, seq_len + chunk_size - 1, dim)。
        """
        # 获取块大小
        chunk_size = self.chunk_size
        # 获取批次大小和序列长度
        batch, seq_len = seq.shape[:2]

        # 对输入序列应用检索前的归一化
        seq = self.retrieve_norm(seq)

        assert seq_len >= chunk_size

        # 截取序列，从第 (chunk_size - 1) 个时间步开始
        seq = seq[:, (chunk_size - 1):]
        # 获取截取后的序列长度
        curtailed_seq_len = seq.shape[-2]

        # 计算下一个序列长度，向上取整到块大小的倍数
        next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)

        # 计算需要填充的长度
        padding = next_seq_len - curtailed_seq_len

        # 判断是否需要填充
        needs_pad = padding > 0

        if needs_pad:
            # 如果需要填充，则在序列维度上填充，填充值为0
            seq = pad_at_dim(seq, (0, padding), dim = 1)

        # the parameters of the memory model stores the memories of the key / values
        # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
        # 获取当前记忆模型的权重
        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

        if exists(past_weights):
            # 如果存在过去权重，则将其转换为 TensorDict 对象，并断言键与当前权重一致
            past_weights = TensorDict(past_weights)
            assert past_weights.keys() == curr_weights.keys()
            
            # 将当前权重与过去权重相加
            curr_weights = curr_weights + past_weights

        # sequence Float['b n d'] to queries
        # 将序列从 Float['b n d'] 转换为查询
        queries = self.to_queries(seq)

        # maybe multihead
        # 处理多头
        queries = self.split_heads(queries)

        # fetch values from memory model
        # 重塑权重张量形状为 (batch * n, ...)，以便与查询对齐
        curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
        # 重塑查询张量形状为 (batch * n, c, d)，其中 c 是块内维度
        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)

        # forward functional call
        # 使用记忆模型进行前向传播，获取值
        values = functional_call(self.memory_model, dict(curr_weights), queries)

        # reconstitute batch dimension
        # 恢复批次和头的维度，形状为 (batch, heads, n * c, d)
        values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = self.heads)

        # 应用多头RMS归一化
        values = self.multihead_rmsnorm(values)

        # maybe gate
        # 如果存在检索门控机制，则应用门控
        if exists(self.retrieve_gate):
            values = values * self.retrieve_gate(seq)

        # maybe merge heads and combine
        # 合并多头
        values = self.merge_heads(values)

        # 组合多头
        values = self.combine_heads(values)

        # restore, pad with empty memory embed
        # 恢复填充：
        # 初始化空记忆嵌入，形状为 (batch, chunk_size - 1, dim)
        empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
        # 将空记忆嵌入与检索到的记忆连接起来，形状为 (batch, chunk_size, dim)
        values = torch.cat((empty_memory_embeds, values), dim = -2)

        if needs_pad:
            # 如果之前进行了填充，则去除末尾的填充部分
            values = values[:, :-padding]

        # 返回检索到的记忆
        return values

    def forward(
        self,
        seq,
        store_seq = None,
        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
        return_next_memories = False
    ):
        """
        前向传播方法。

        该方法实现了记忆的存储、检索以及更新过程。根据输入序列和过去状态，模型可以存储新的记忆，检索现有的记忆，并返回当前或下一个记忆状态。

        参数:
            seq (Tensor): 输入序列，形状为 (batch, seq_len, dim)。
                - `batch`: 批次大小。
                - `seq_len`: 序列长度。
                - `dim`: 特征维度。
            store_seq (Tensor, 可选): 用于存储的序列，默认为 None。
                - 如果为 None，则使用输入序列 `seq` 进行记忆存储。
            past_state (Tuple[Dict[str, Tensor], Dict[str, Tensor]], 可选): 过去的状态，包含权重和动量，默认为 None。
                - 第一个字典包含过去的权重。
                - 第二个字典包含过去的动量。
            return_next_memories (bool, 可选): 是否返回下一个记忆状态，默认为 False。
                - 如果为 True，则返回更新后的权重和动量。
                - 如果为 False，则仅返回检索到的记忆。

        返回:
            Tuple[Tensor, Optional[Tuple[Dict[str, Tensor], Dict[str, Tensor]]]]:
                - 如果 `return_next_memories` 为 False，则返回检索到的记忆，形状为 (batch, seq_len + chunk_size - 1, dim)。
                - 如果 `return_next_memories` 为 True，则返回一个包含检索到的记忆和下一个记忆状态的元组。
        """
        # 获取输入序列的批次大小和序列长度
        batch, seq_len = seq.shape[:2]

        if seq_len < self.chunk_size:
            # 如果序列长度小于块大小，则返回初始化后的空记忆嵌入
            return self.init_empty_memory_embed(batch, seq_len)

        if exists(past_state):
            # 如果存在过去状态，则将其转换为 TensorDict 对象
            past_state = tuple(TensorDict(d) for d in past_state)

        if not exists(past_state):
            # 如果不存在过去状态，则初始化权重和动量
            past_state = self.init_weights_and_momentum()

        # 如果未提供存储序列，则使用输入序列
        store_seq = default(store_seq, seq)

        # 存储记忆并获取更新和下一个记忆状态
        updates, next_memories = self.store_memories(store_seq, past_state)

        # 获取过去的权重
        past_weights, _ = past_state

        # 检索记忆：使用过去的权重和更新进行检索
        retrieved = self.retrieve_memories(seq, past_weights + updates)

        if not return_next_memories:
            # 如果不需要返回下一个记忆状态，则返回检索到的记忆
            return retrieved

        # 如果需要返回下一个记忆状态，则返回检索到的记忆和下一个记忆状态
        return retrieved, next_memories
